from __future__ import absolute_import

from keras.layers import Input, Embedding, Dense, TimeDistributed
from keras.models import Model

from ..embeddings import get_embeddings_index, build_embedding_weights
from .sequence_encoders import SequenceEncoderBase


class SentenceModelFactory(object):
    """Factory for Keras models that classify documents given as (samples, max_sentences, max_tokens) token ids."""

    def __init__(self, num_classes, token_index, max_sents, max_tokens,
                 embedding_type='glove.6B.100d', embedding_dims=100):
        """Creates a `SentenceModelFactory` instance for building various models that operate over
        (samples, max_sentences, max_tokens) input.

        Args:
            num_classes: The number of output classes.
            token_index: The dictionary of token and its corresponding integer index value.
            max_sents: The max number of sentences in a document.
            max_tokens: The max number of tokens in a sentence.
            embedding_type: The embedding type to use. Set to None to use random embeddings.
                (Default value: 'glove.6B.100d')
            embedding_dims: The number of embedding dims to use for representing a word. This argument will be ignored
                when `embedding_type` is set. (Default value: 100)

        Raises:
            ValueError: If `max_tokens` is None, or if the embeddings index loaded for `embedding_type` is empty.
        """
        self.num_classes = num_classes
        self.token_index = token_index
        self.max_sents = max_sents
        self.max_tokens = max_tokens

        # This is required to make TimeDistributed(word_encoder_model) work.
        # TODO: Get rid of this restriction when https://github.com/fchollet/keras/issues/6917 resolves.
        if self.max_tokens is None:
            raise ValueError('`max_tokens` should be provided.')

        if embedding_type is not None:
            self.embeddings_index = get_embeddings_index(embedding_type)
            # The dimensionality is dictated by the loaded vectors, not by the `embedding_dims` argument.
            self.embedding_dims = self._infer_embedding_dims(self.embeddings_index)
        else:
            self.embeddings_index = None
            self.embedding_dims = embedding_dims

    @staticmethod
    def _infer_embedding_dims(embeddings_index):
        """Returns the size of the last axis of the first vector stored in `embeddings_index`.

        Args:
            embeddings_index: A dictionary whose values expose a `shape` attribute.

        Raises:
            ValueError: If `embeddings_index` has no entries.
        """
        # `dict.values()` is a non-indexable view on Python 3, so take the first item through an iterator.
        try:
            first_vector = next(iter(embeddings_index.values()))
        except StopIteration:
            raise ValueError('`embeddings_index` is empty; cannot infer `embedding_dims`.')
        return first_vector.shape[-1]

    def build_model(self, token_encoder_model, sentence_encoder_model,
                    trainable_embeddings=True, output_activation='softmax'):
        """Builds a model that first encodes all words within sentences using `token_encoder_model`, followed by
        `sentence_encoder_model`.

        Args:
            token_encoder_model: An instance of `SequenceEncoderBase` for encoding tokens within sentences. This model
                will be applied across all sentences to create a sentence encoding.
            sentence_encoder_model: An instance of `SequenceEncoderBase` operating on sentence encoding generated by
                `token_encoder_model`. This encoding is then fed into a final `Dense` layer for classification.
            trainable_embeddings: Whether or not to fine tune embeddings.
            output_activation: The output activation to use. (Default value: 'softmax')
                Use:
                - `softmax` for binary or multi-class.
                - `sigmoid` for multi-label classification.
                - `linear` for regression output.

        Returns:
            A `Model` that maps a `(max_sents, max_tokens)` int32 input to the output of the final
            `Dense(num_classes)` layer.

        Raises:
            ValueError: If either encoder is not a `SequenceEncoderBase` instance, or if `sentence_encoder_model`
                does not allow dynamic length and `max_sents` was not provided.
        """
        if not isinstance(token_encoder_model, SequenceEncoderBase):
            raise ValueError("`token_encoder_model` should be an instance of `{}`".format(SequenceEncoderBase))
        if not isinstance(sentence_encoder_model, SequenceEncoderBase):
            raise ValueError("`sentence_encoder_model` should be an instance of `{}`".format(SequenceEncoderBase))

        if not sentence_encoder_model.allows_dynamic_length() and self.max_sents is None:
            raise ValueError("Sentence encoder model '{}' requires padding. "
                             "You need to provide `max_sents`".format(type(sentence_encoder_model).__name__))

        embedding_kwargs = {
            'input_length': self.max_tokens,
            'mask_zero': True,
            'trainable': trainable_embeddings,
        }
        if self.embeddings_index is not None:
            # Initialize the layer from the pre-trained vectors instead of random weights.
            embedding_kwargs['weights'] = [build_embedding_weights(self.token_index, self.embeddings_index)]

        # The +1 is for unknown token index 0.
        embedding_layer = Embedding(len(self.token_index) + 1, self.embedding_dims, **embedding_kwargs)

        # Sentence-level sub-model: token ids of one sentence -> sentence encoding.
        word_input = Input(shape=(self.max_tokens,), dtype='int32')
        word_encoding = token_encoder_model(embedding_layer(word_input))
        word_encoder = Model(word_input, word_encoding, name='word_encoder')

        # Document-level model: apply the sentence sub-model to every sentence, then encode the sentence sequence.
        doc_input = Input(shape=(self.max_sents, self.max_tokens), dtype='int32')
        sent_encoding = TimeDistributed(word_encoder)(doc_input)
        doc_encoding = sentence_encoder_model(sent_encoding)

        output = Dense(self.num_classes, activation=output_activation)(doc_encoding)
        return Model(doc_input, output)
